# 给定一个根为 root 的二叉树，每个节点的深度是 该节点到根的最短距离 。
#  返回包含原始树中所有 最深节点 的 最小子树。
#  如果一个节点在 整个树 的任意节点之间具有最大的深度，则该节点是 最深的 。
#  一个节点的 子树 是该节点加上它的所有后代的集合。
#
#  示例 1：
# 输入：root = [3,5,1,6,2,0,8,null,null,7,4]
# 输出：[2,7,4]
# 解释：
# 我们返回值为 2 的节点，在图中用黄色标记。
# 在图中用蓝色标记的是树的最深的节点。
# 注意，节点 5、3 和 2 包含树中最深的节点，但节点 2 的子树最小，因此我们返回它。
#
#  示例 2：
# 输入：root = [1]
# 输出：[1]
# 解释：根节点是树中最深的节点。
#
#  示例 3：
# 输入：root = [0,1,3,null,2]
# 输出：[2]
# 解释：树中最深的节点为 2 ，有效子树为节点 2、1 和 0 的子树，但节点 2 的子树最小。
from com.example.tree.tree_node import TreeNode


class Solution:
    def subtreeWithAllDeepest(self, root: TreeNode) -> TreeNode:
        """
        如果当前节点是最深叶子节点的最近公共祖先，那么它的左右子树的高度一定是相等的，
        否则高度低的那个子树的叶子节点深度一定比另一个子树的叶子节点的深度小，因此不满足条件。
        所以只需要dfs遍历找到左右子树高度相等的根节点即出答案。
        :return:
        """
        def getDepth(node: TreeNode) -> int:  # 获取节点的深度
            return max(getDepth(node.left), getDepth(node.right)) + 1 if node else 0

        def dfs(node: TreeNode) -> TreeNode:
            if not node:
                return None
            leftDepth = getDepth(node.left)
            rightDepth = getDepth(node.right)
            if leftDepth == rightDepth:
                return node
            elif leftDepth > rightDepth:
                return dfs(node.left)
            elif leftDepth < rightDepth:
                return dfs(node.right)

        return dfs(root)


if __name__ == "__main__":
    #                 3
    #        5                   1
    #    6        2         0         8
    #          7     4
    root = TreeNode(3)
    root.left, root.right = TreeNode(5), TreeNode(1)
    root.left.left, root.left.right, root.right.left, root.right.right = TreeNode(6), TreeNode(2), TreeNode(0), TreeNode(8)
    root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)
    #       0
    #  1          3
    #     2
    root = TreeNode(0)
    root.left, root.right = TreeNode(1), TreeNode(3)
    root.left.right = TreeNode(2)

    ans = Solution().subtreeWithAllDeepest(root)

    tmpList = []

    def travel(node: TreeNode) -> None:
        if not node:
            return
        tmpList.append(node.val)
        travel(node.left)
        travel(node.right)
    travel(ans)
    print(tmpList)
